import argparse
import copy
import logging
import os
from functools import partial

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.nn import global_mean_pool, global_add_pool
from tqdm import tqdm

from chem.data_process import set_random_seed
from chem.dataloader import DataLoaderMasking2
from loader_we import MoleculeDataset2
from model import GNN
from splitters import scaffold_split, random_split


class CosineDecayScheduler:
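    """Linear warmup for `warmup_steps` steps followed by cosine decay to zero at
    `total_steps`; used below to schedule both the learning rate and the
    target-network momentum."""
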
    def __init__(self, max_val, warmup_steps, total_steps):
        self.max_val = max_val
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps

    def get(self, step):
        if step < self.warmup_steps:
            return self.max_val * step / self.warmup_steps
        elif self.warmup_steps <= step <= self.total_steps:
            return self.max_val * (1 + np.cos((step - self.warmup_steps) * np.pi /
                                              (self.total_steps - self.warmup_steps))) / 2
        else:
            raise ValueError('Step ({}) > total number of steps ({}).'.format(step, self.total_steps))


def sce_loss(x, y, alpha=1):
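    """Scaled cosine error: mean over rows of (1 - cos_sim(x, y)) ** alpha."""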
    x = F.normalize(x, p=2, dim=-1)
    y = F.normalize(y, p=2, dim=-1)
    loss = (1 - (x * y).sum(dim=-1)).pow_(alpha)
    loss = loss.mean()
    return loss


def mask(g, mask_rate=0.5):
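    """Uniformly sample a fraction `mask_rate` of the nodes in `g` and return their indices."""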
    num_nodes = g.x.shape[0]
    perm = torch.randperm(num_nodes, device=g.x.device)
    num_mask_nodes = int(mask_rate * num_nodes)
    mask_nodes = perm[:num_mask_nodes]
    return mask_nodes


class D_CG(nn.Module):
    # BYOL-style online/target encoder pair; this variant feeds the unmasked graph to both encoders.
    def __init__(self, num_layer, pre_dim, emb_dim, JK, drop_ratio=0, mask_rate=0, random_mask=False):
        super(D_CG, self).__init__()
        self.drop_rate = drop_ratio
        self.online_encoder = GNN(num_layer, pre_dim=pre_dim, emb_dim=emb_dim, JK=JK, drop_ratio=self.drop_rate,
                                  pre=True)
        self.target_encoder = GNN(num_layer, pre_dim=pre_dim, emb_dim=emb_dim, JK=JK, drop_ratio=0,
                                  pre=True)
        self.mask_rate = mask_rate
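        # Learnable token substituted for masked node features (unused by this variant's forward pass).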
        self.enc_mask_token = nn.Parameter(torch.zeros(1, pre_dim))
        self.criterion = self.setup_loss_fn("sce", 1)
        self.pool_mean = global_mean_pool
        self.pool_add = global_add_pool
        self.random_mask = random_mask
        for param in self.target_encoder.parameters():
            param.requires_grad = False

    def setup_loss_fn(self, loss_fn, alpha_l):
        if loss_fn == "mse":
            criterion = nn.MSELoss()
        elif loss_fn == "sce":
            criterion = partial(sce_loss, alpha=alpha_l)
        else:
            raise NotImplementedError
        return criterion

    def trainable_parameters(self):
        r"""Returns the parameters that will be updated via an optimizer."""
        return list(self.online_encoder.parameters())

    @torch.no_grad()
    def update_target_network(self, mm):
        r"""Performs a momentum update of the target network's weights.
        Args:
            mm (float): Momentum used in moving average update.
        """
        for param_q, param_k in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
            param_k.data.mul_(mm).add_(param_q.data, alpha=1. - mm)

    def forward(self, graph_batch, device):
        graph_batch = graph_batch.to(device)
        h1 = self.online_encoder(graph_batch.x, graph_batch.edge_index, graph_batch.edge_attr)
        with torch.no_grad():
            h2 = self.target_encoder(graph_batch.x, graph_batch.edge_index, graph_batch.edge_attr)
        loss_2 = self.criterion(h1, h2.detach())
        return loss_2

    def get_embed(self, graph_batch, device):
        with torch.no_grad():
            graph_batch = graph_batch.to(device)
            h1 = self.online_encoder(graph_batch.x, graph_batch.edge_index, graph_batch.edge_attr)
            h = self.pool_add(h1, graph_batch.batch)
        return h.detach()


class CG(nn.Module):
    # BYOL-style online/target encoder pair trained on masked node features.
    def __init__(self, num_layer, pre_dim, emb_dim, JK, drop_ratio, mask_rate, random_mask=False):
        super(CG, self).__init__()
        self.drop_rate = drop_ratio
        self.online_encoder = GNN(num_layer, pre_dim=pre_dim, emb_dim=emb_dim, JK=JK, drop_ratio=self.drop_rate,
                                  pre=True)
        self.target_encoder = copy.deepcopy(self.online_encoder)
        self.target_encoder.reset_parameters()
        self.mask_rate = mask_rate
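        # Learnable token substituted for the features of masked nodes.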
        self.enc_mask_token = nn.Parameter(torch.zeros(1, pre_dim))
        self.criterion = self.setup_loss_fn("sce", 1)
        self.pool_mean = global_mean_pool
        self.pool_add = global_add_pool
        self.random_mask = random_mask
        for param in self.target_encoder.parameters():
            param.requires_grad = False

    def setup_loss_fn(self, loss_fn, alpha_l):
        if loss_fn == "mse":
            criterion = nn.MSELoss()
        elif loss_fn == "sce":
            criterion = partial(sce_loss, alpha=alpha_l)
        else:
            raise NotImplementedError
        return criterion

    def trainable_parameters(self):
        r"""Returns the parameters that will be updated via an optimizer."""
        return list(self.online_encoder.parameters())

    @torch.no_grad()
    def update_target_network(self, mm):
        r"""Performs a momentum update of the target network's weights.
        Args:
            mm (float): Momentum used in moving average update.
        """
        for param_q, param_k in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
            param_k.data.mul_(mm).add_(param_q.data, alpha=1. - mm)

    def forward(self, graph_batch, device):
        if self.random_mask:
            graph_batch = graph_batch.to(device)
            # Sample a fresh random mask each forward pass (rate hard-coded to 0.1 here).
            mask_nodes = mask(graph_batch, mask_rate=0.1)
            # Cast to float before adding the mask token, mirroring the pre-computed-mask branch;
            # the original .long() cast truncated the token to zeros and cut its gradient.
            x = graph_batch.x.clone().float()
            x[mask_nodes] = 0.0
            x[mask_nodes] += self.enc_mask_token
            h1 = self.online_encoder(x, graph_batch.edge_index, graph_batch.edge_attr)
            with torch.no_grad():
                h2 = self.target_encoder(graph_batch.x.float(), graph_batch.edge_index, graph_batch.edge_attr)
            loss_2 = self.criterion(h1[mask_nodes], h2[mask_nodes].detach())
        else:
            graph_batch = graph_batch.to(device)
            x = graph_batch.x.clone().float().to(device)
            mask_nodes = graph_batch.mask_nodes.long()
            x[mask_nodes] = 0.0
            x[mask_nodes] += self.enc_mask_token
            h1 = self.online_encoder(x, graph_batch.edge_index, graph_batch.edge_attr)
            with torch.no_grad():
                h2 = self.target_encoder(graph_batch.x.float(), graph_batch.edge_index, graph_batch.edge_attr)
            h1_mask_pool = self.pool_mean(h1[mask_nodes], graph_batch.batch[mask_nodes])
            h2_mask_pool = self.pool_mean(h2[mask_nodes].detach(), graph_batch.batch[mask_nodes])
            loss_2 = self.criterion(h1_mask_pool, h2_mask_pool)
        return loss_2

    def get_embed(self, graph_batch, device):
        with torch.no_grad():
            graph_batch = graph_batch.to(device)
            h1 = self.online_encoder(graph_batch.x, graph_batch.edge_index, graph_batch.edge_attr)
            h = self.pool_add(h1, graph_batch.batch)
        return h.detach()


class LogReg(torch.nn.Module):
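    """Linear probe trained on frozen graph embeddings for downstream evaluation."""
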
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.graph_pred_linear = torch.nn.Linear(in_dim, out_dim)

    def forward(self, h):
        z = self.graph_pred_linear(h)
        return z


def compute_roc(log, model, loader, device):
    """Average ROC-AUC over tasks, skipping tasks without both positive and negative labels."""
    y_true = []
    y_scores = []
    log.eval()
    for step, batch in enumerate(loader):
        with torch.no_grad():
            z = model.get_embed(batch, device)
            pred = log(z)
        y_true.append(batch.y.view(pred.shape))
        y_scores.append(pred)
    y_true = torch.cat(y_true, dim=0).cpu().detach().numpy()
    y_scores = torch.cat(y_scores, dim=0).cpu().detach().numpy()
    roc_list = []
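    # Labels are stored as {-1, 0, +1}, where 0 marks a missing label; (y + 1) / 2 maps {-1, +1} to {0, 1}.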
    for i in range(y_true.shape[1]):
        if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == -1) > 0:
            is_valid = y_true[:, i] ** 2 > 0
            roc_list.append(roc_auc_score((y_true[is_valid, i] + 1) / 2, y_scores[is_valid, i]))
    if len(roc_list) < y_true.shape[1]:
        print("some target is missing")
        print("missing ratio: %f" % (1 - float(len(roc_list)) / y_true.shape[1]))
    roc = sum(roc_list) / len(roc_list)
    return roc


def evaluation(model, writer, loader, train_loader, val_loader, test_loader, num_tasks, device,
               num_test_epoch=100):
    xent = nn.BCEWithLogitsLoss(reduction="none")
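    # NOTE: the probe input dim is hard-coded to 300, matching the default --emb_dim.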
    log = LogReg(300, num_tasks).to(device)
    opt = torch.optim.Adam(log.parameters(), lr=0.01, weight_decay=0.0)
    loss_list, train_roc, val_roc, test_roc = [], [], [], []
    pdar_2 = tqdm(range(num_test_epoch))
    best_val_roc = 0
    best_test = 0
    test_best_roc = []
    model.eval()
    for epoch in pdar_2:
        total_loss = []
        for step, batch in enumerate(loader):
            with torch.no_grad():
                z = model.get_embed(batch, device)
            # train
            log.train()
            pred = log(z)
            y = batch.y.view(pred.shape).to(torch.float64)
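            # Entries with y == 0 are unlabeled and are masked out of the loss below.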
            is_valid = y ** 2 > 0
            loss_mat = xent(pred.double(), (y + 1) / 2)
            loss_mat = torch.where(is_valid, loss_mat,
                                   torch.zeros(loss_mat.shape).to(loss_mat.device).to(loss_mat.dtype))
            opt.zero_grad()
            loss = torch.sum(loss_mat) / torch.sum(is_valid)
            total_loss.append(loss.item())

            loss.backward()
            opt.step()

        avg_loss = sum(total_loss) / len(total_loss)
        logging.info("Epoch: {}, Loss: {}".format(epoch, round(avg_loss, 3)))
        # ROC on the train / val / test splits
        roc_train = compute_roc(log, model, train_loader, device)
        roc_val = compute_roc(log, model, val_loader, device)
        roc_test = compute_roc(log, model, test_loader, device)
        # Track the best validation ROC, breaking ties by test ROC. The original reset
        # best_test every epoch and `continue`d past the logging below on tie-breaks.
        if roc_val > best_val_roc or (roc_val == best_val_roc and roc_test > best_test):
            best_val_roc = roc_val
            best_test = roc_test
            test_best_roc.append([best_val_roc, roc_test, epoch])
        writer.add_scalar("train_loss", avg_loss, epoch)
        writer.add_scalar("train_roc", roc_train, epoch)
        writer.add_scalar("val_roc", roc_val, epoch)
        writer.add_scalar("test_roc", roc_test, epoch)
        loss_list.append(avg_loss)
        train_roc.append(roc_train)
        val_roc.append(roc_val)
        test_roc.append(roc_test)
        pdar_2.set_description(
            f"Epoch: {epoch}, Loss: {round(avg_loss, 3)}, Train_roc: {round(roc_train, 3)}, "
            f"Val_roc: {round(roc_val, 3)}, Test_roc: {round(roc_test, 3)}")
    return loss_list, train_roc, val_roc, test_roc, test_best_roc


def train(args, model, loader, train_loader, val_loader, test_loader, device):
    dataname = args.dataset
    save_name = args.save_name
    num_epochs = args.epochs
    num_tasks = args.num_tasks
    num_test_epoch = args.test_epochs
    writer = SummaryWriter(log_dir=f"./logs/{dataname}/{args.out_file}/{save_name}")
    lr = args.lr
    mm_l = args.mm
    lr_scheduler = CosineDecayScheduler(lr, 100, num_epochs)
    mm_scheduler = CosineDecayScheduler(1 - mm_l, 0, num_epochs)
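    # The scheduler tracks (1 - mm), so the target-network momentum ramps towards 1 over training.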
    optimizer = Adam(model.trainable_parameters(), lr=lr, weight_decay=1e-4)
    print("开始进行训练......")
    best_loss = float('inf')
    pdar = tqdm(range(1, num_epochs + 1))
    for epoch in pdar:
        model.train()
        # update learning rate
        lr = lr_scheduler.get(epoch - 1)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        mm = 1 - mm_scheduler.get(epoch - 1)
        losses = 0
        for gs in loader:
            loss = model(gs, device)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            model.update_target_network(mm)
            losses += loss.item()
        avg_loss = losses / len(loader)
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), f"./result/{dataname}/{args.out_file}/{save_name}/{dataname}_best_model.pt")
        pdar.set_description(f"Epoch: {epoch}, Loss: {round(avg_loss, 5)}")
        writer.add_scalar("ContrastiveTrain_Loss", round(avg_loss, 5), epoch)

    loss_list, train_roc, val_roc, test_roc, test_best_roc = evaluation(model, writer, loader, train_loader, val_loader,
                                                                        test_loader, num_tasks, device,
                                                                        num_test_epoch)
    # Write per-epoch training curves.
    with open(f"./result/{dataname}/{args.out_file}/{save_name}/result_of_train.txt", 'w') as f:
        for i in range(len(loss_list)):
            f.write(f'{loss_list[i]:.4f},{train_roc[i]:.4f},{val_roc[i]:.4f},{test_roc[i]:.4f}\n')
    # Write the sequence of best (val, test) checkpoints, with the header written once.
    with open(f"./result/{dataname}/{args.out_file}/{save_name}/result_of_roc_test.txt", 'w') as f:
        f.write("val_roc,test_roc,epoch\n")
        for val_roc_i, test_roc_i, epoch_i in test_best_roc:
            f.write(f'{round(val_roc_i, 4)},{round(test_roc_i, 4)},{epoch_i}\n')
    return test_best_roc[-1]


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
    ## necessary args
    parser.add_argument('--dataset', type=str, default='bbbp')
    parser.add_argument('--input_model_file', type=str,
                        default='./result/pre_training/Method_3_0.25dropout/zinc_standard_agent_best_model.pth',
                        help='filename to read the model (if there is any)')
    parser.add_argument("--seeds", type=int, nargs="+", default=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    parser.add_argument('--save_name', type=str, default='exp1')
    parser.add_argument('--batch_size', type=int, default=256, help='input batch size for training (default: 256)')
    parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train (default: 100)')
    parser.add_argument('--pre_dim', type=int, default=2)
    parser.add_argument('--test_epochs', type=int, default=100, help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate (default: 0.001)')
    parser.add_argument('--num_layer', type=int, default=3, help='number of GNN message passing layers (default: 3)')
    parser.add_argument('--emb_dim', type=int, default=300, help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio', type=float, default=0.25, help='dropout ratio (default: 0.25)')
    parser.add_argument('--mask_ratio', type=float, default=0.0, help='mask ratio (default: 0.0)')
    parser.add_argument('--JK', type=str, default="last")
    parser.add_argument('--split', type=str, default="scaffold", help="random or scaffold or random_scaffold")
    parser.add_argument('--num_workers', type=int, default=4, help='number of workers for dataset loading')
    # hander_mt: 1 = leave functional groups untouched; 2 = keep long functional groups as well as the
    # deduplicated short ones; 3 = keep short functional groups and do not mask the long ones.
    parser.add_argument('--hander_mt', type=int, default=3, help='1, 2 or 3')
    parser.add_argument('--mm', type=float, default=0.999, help='target-network momentum, e.g. 0.99 or 0.999')
    # argparse's plain type=bool treats any non-empty string as True, so parse the flag explicitly.
    parser.add_argument('--pre_mask', type=lambda s: str(s).lower() in ('true', '1'), default=True,
                        help='whether to use pre-computed masks')

    args = parser.parse_args()
    print(args)
    args_Dict = args.__dict__
    seeds = args.seeds
    dataname = args.dataset
    split = args.split
    batch_size = args.batch_size
    num_workers = args.num_workers

    if dataname == "tox21":
        num_tasks = 12
    elif dataname == "hiv":
        num_tasks = 1
    elif dataname == "muv":
        num_tasks = 17
    elif dataname == "bace":
        num_tasks = 1
    elif dataname == "bbbp":
        num_tasks = 1
    elif dataname == "toxcast":
        num_tasks = 617
    elif dataname == "sider":
        num_tasks = 27
    elif dataname == "clintox":
        num_tasks = 2
    elif dataname == "ptc_mr":
        num_tasks = 1
    elif dataname == "mutag":
        num_tasks = 1
    else:
        raise ValueError("Invalid dataset name.")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    dataset = MoleculeDataset2("./dataset/" + dataname, dataset=dataname, mask_ratio=args.mask_ratio,
                               hander_mt=args.hander_mt)
    if split == "scaffold":
        smiles_list = pd.read_csv('./dataset/' + dataname + '/processed/smiles.csv', header=None)[
            0].tolist()
        train_dataset, valid_dataset, test_dataset = scaffold_split(dataset, smiles_list, null_value=0,
                                                                    frac_train=0.8,
                                                                    frac_valid=0.1, frac_test=0.1)
        print("数据集划分方法为: scaffold")

    elif split == "random":
        smiles_list = pd.read_csv('./dataset/' + dataname + '/processed/smiles.csv', header=None)[
            0].tolist()
        train_dataset, valid_dataset, test_dataset, (train_smiles, valid_smiles,
                                                     test_smiles) = random_split(dataset, task_idx=None, null_value=0,
                                                                                 frac_train=0.8, frac_valid=0.1,
                                                                                 frac_test=0.1,
                                                                                 seed=0,
                                                                                 smiles_list=smiles_list)
        print("数据集划分方法为: scaffold")
    else:
        raise ValueError("Invalid split option.")

    loader = DataLoaderMasking2(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    train_loader = DataLoaderMasking2(train_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    val_loader = DataLoaderMasking2(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoaderMasking2(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    args.num_tasks = num_tasks

    roc_list = []
    test_roc_list = []

    j = 1
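    # Allocate the first unused ./result/<dataset>/test<j> output folder for this run.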
    while True:
        out_file = 'test'
        out_file = out_file + str(j)
        if not os.path.exists(f"./result/{dataname}/{out_file}"):
            os.makedirs(f"./result/{dataname}/{out_file}")
            args.out_file = out_file
            break
        j = j + 1
    with open(f"./result/{dataname}/{out_file}/args.txt", 'w', encoding='utf-8') as f:
        for eachArg, value in args_Dict.items():
            f.writelines(eachArg + ' : ' + str(value) + '\n')
    for i, seed in enumerate(seeds):
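        # Allocate a fresh expN sub-folder for this seed's results.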
        folder_name = "exp"
        j = 1
        while True:
            new_folder_name = folder_name + str(j)
            if not os.path.exists(f"./result/{dataname}/{out_file}/{new_folder_name}"):
                os.makedirs(f"./result/{dataname}/{out_file}/{new_folder_name}")
                args.save_name = new_folder_name
                break
            j += 1
        print(f"####### Run {i} for seed {seed}")
        set_random_seed(seed)

        model = CG(args.num_layer, args.pre_dim, args.emb_dim, args.JK, args.dropout_ratio,
                   mask_rate=args.mask_ratio, random_mask=True).to(device)
        if dataname not in ['mutag', 'ptc_mr']:
            model.load_state_dict(torch.load(args.input_model_file, map_location=device))
        test_roc = train(args, model, loader, train_loader, val_loader, test_loader, device)
        roc_list.append(test_roc)
        test_roc_list.append(test_roc[1])
    final_roc, final_roc_std = np.mean(test_roc_list), np.std(test_roc_list)
    print(final_roc, final_roc_std)
    # Summary across seeds, written at the run level so per-seed result files are not overwritten.
    with open(f"./result/{dataname}/{args.out_file}/result_of_seeds.txt", "w") as f:
        f.write("val_roc,test_roc,epoch\n")
        for val_roc, test_roc, epoch in roc_list:
            f.write(f'{round(val_roc, 4)},{round(test_roc, 4)},{epoch}\n')
        f.write("# final_roc: " + str(final_roc) + " ± " + str(final_roc_std) + "\n")


if __name__ == '__main__':
    main()

    # Sanity check (kept for reference): average fraction of masked nodes per batch.
    # dataname = 'sider'
    # dataset = MoleculeDataset2("./dataset/" + dataname, dataset=dataname, mask_ratio=0.24, hander_mt=1)
    # loader = DataLoaderMasking2(dataset, batch_size=256, shuffle=False, num_workers=0)
    # ratio = 0
    # for batch in loader:
    #     ratio += len(batch.mask_nodes) / batch.x.size(0)
    # print(ratio / len(loader))
